import os
import clip
import torch
from torchvision.datasets import CIFAR100
from PIL import Image

# Run on the GPU when CUDA is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP checkpoint onto the chosen device. `preprocess` is
# the second value returned by clip.load; it is applied to each PIL image
# before the image is handed to the model.
model, preprocess = clip.load("ViT-B/32", device)

# The four sample pictures that are scored against the text prompts.
image_paths = (
    "aTensor/Picture1.jpg",
    "aTensor/Picture2.jpg",
    "aTensor/Picture3.jpg",
    "aTensor/Picture4.jpg",
)

# Open all four pictures first (same order as the paths above).
p1, p2, p3, p4 = (Image.open(path) for path in image_paths)

# Run each picture through the CLIP preprocessing transform, add a leading
# dimension of size 1 with unsqueeze(0), and move the result to `device`.
image_input1, image_input2, image_input3, image_input4 = (
    preprocess(picture).unsqueeze(0).to(device) for picture in (p1, p2, p3, p4)
)

# Download the dataset
# cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)


# Candidate captions: four descriptions of apple pie plus one of bread pudding.
prompts = (
    "This image is of a round, golden-crusted apple pie.",
    "The image is of a homemade apple pie with a lattice crust.",
    "In the image, there is a slice of apple pie on a white plate.",
    "A apple pie looks like a pie with a crust and apples inside.",
    "A photo of bread pudding.",
)

# Tokenize each caption separately and move the tokens to `device`.
text_input1, text_input2, text_input3, text_input4, text_input5 = (
    clip.tokenize(prompt).to(device) for prompt in prompts
)

# Join the per-caption token tensors along the first dimension (torch.cat's
# default) so all captions can be encoded in a single call.
text_input = torch.cat(
    (text_input1, text_input2, text_input3, text_input4, text_input5)
)

# Encode every image and all captions with gradient tracking disabled.
with torch.no_grad():
    # The generator is consumed by the unpacking right here, so each
    # encode_image call still runs inside the no_grad context.
    image_features1, image_features2, image_features3, image_features4 = (
        model.encode_image(pixels)
        for pixels in (image_input1, image_input2, image_input3, image_input4)
    )

    text_features = model.encode_text(text_input)

# Join the per-image embeddings along the first dimension.
image_features = torch.cat(
    (image_features1, image_features2, image_features3, image_features4)
)

# Scale every embedding to unit L2 norm along its last dimension, so the
# matrix product below yields cosine similarities.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Score each image against each caption: multiply the image embeddings by the
# fixed factor 100.0, take the product with the transposed caption embeddings,
# and apply softmax across the captions (last dimension).
# NOTE(review): 100.0 presumably stands in for CLIP's logit scale - confirm.
logits = (100.0 * image_features) @ text_features.T
similarity = logits.softmax(dim=-1)

# Print the full image-by-caption probability matrix.
print(similarity)
# values, indices = similarity[0].topk(5)

# Print the result
# print("\nTop predictions:\n")
# for value, index in zip(values, indices):
#     print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")